{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "no gpu found\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "from wavenet_model import *\n",
    "from audio_data import WavenetDataset\n",
    "from wavenet_training import *\n",
    "from model_logging import *\n",
    "from scipy.io import wavfile\n",
    "\n",
    "# Pick the float/long tensor types once, so later cells can cast with\n",
    "# `dtype` / `ltype` without caring whether a GPU is present.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "if use_cuda:\n",
    "    print('use gpu')\n",
    "    dtype, ltype = torch.cuda.FloatTensor, torch.cuda.LongTensor\n",
    "else:\n",
    "    print(\"no gpu found\")\n",
    "    dtype, ltype = torch.FloatTensor, torch.LongTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model:  WaveNetModel(\n",
      "  (filter_convs): ModuleList(\n",
      "    (0): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (1): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (2): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (3): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (4): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (5): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (6): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (7): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (8): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (9): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (10): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (11): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (12): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (13): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (14): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (15): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (16): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (17): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (18): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (19): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (20): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (21): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (22): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (23): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (24): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (25): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (26): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (27): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (28): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (29): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (30): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (31): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (32): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (33): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (34): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (35): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (36): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (37): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (38): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (39): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "  )\n",
      "  (gate_convs): ModuleList(\n",
      "    (0): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (1): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (2): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (3): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (4): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (5): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (6): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (7): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (8): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (9): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (10): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (11): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (12): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (13): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (14): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (15): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (16): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (17): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (18): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (19): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (20): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (21): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (22): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (23): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (24): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (25): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (26): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (27): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (28): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (29): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (30): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (31): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (32): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (33): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (34): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (35): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (36): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (37): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (38): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "    (39): Conv1d (32, 32, kernel_size=(2,), stride=(1,), bias=False)\n",
      "  )\n",
      "  (residual_convs): ModuleList(\n",
      "    (0): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (1): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (2): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (3): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (4): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (5): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (6): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (7): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (8): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (9): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (10): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (11): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (12): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (13): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (14): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (15): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (16): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (17): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (18): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (19): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (20): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (21): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (22): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (23): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (24): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (25): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (26): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (27): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (28): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (29): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (30): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (31): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (32): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (33): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (34): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (35): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (36): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (37): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (38): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (39): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "  )\n",
      "  (skip_convs): ModuleList(\n",
      "    (0): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (1): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (2): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (3): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (4): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (5): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (6): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (7): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (8): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (9): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (10): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (11): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (12): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (13): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (14): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (15): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (16): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (17): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (18): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (19): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (20): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (21): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (22): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (23): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (24): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (25): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (26): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (27): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (28): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (29): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (30): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (31): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (32): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (33): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (34): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (35): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (36): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (37): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (38): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "    (39): Conv1d (32, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "  )\n",
      "  (start_conv): Conv1d (1, 32, kernel_size=(1,), stride=(1,), bias=False)\n",
      "  (end_conv): Conv1d (32, 256, kernel_size=(1,), stride=(1,))\n",
      ")\n",
      "receptive field:  4093\n",
      "parameter count:  254240\n"
     ]
    }
   ],
   "source": [
    "# Build the WaveNet instance that the profiling cells below operate on.\n",
    "model = WaveNetModel(\n",
    "    layers=10,\n",
    "    blocks=4,\n",
    "    dilation_channels=32,\n",
    "    residual_channels=32,\n",
    "    skip_channels=32,\n",
    "    output_length=64,\n",
    "    dtype=dtype,\n",
    ")\n",
    "# Disabled alternative: load the latest model from 'snapshots' instead\n",
    "# (presumably a previously saved checkpoint - confirm before enabling).\n",
    "# model = load_latest_model_from('snapshots', use_cuda=use_cuda)\n",
    "if use_cuda:\n",
    "    model.cuda()\n",
    "\n",
    "# Summary of what is about to be profiled.\n",
    "print('model: ', model)\n",
    "print('receptive field: ', model.receptive_field)\n",
    "print('parameter count: ', model.parameter_count())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      "( 0  ,.,.) = \n",
      "   0   0   0  ...    0   0   0\n",
      "\n",
      "( 1  ,.,.) = \n",
      "   0   0   0  ...    0   0   0\n",
      "\n",
      "( 2  ,.,.) = \n",
      "   0   0   0  ...    0   0   0\n",
      " ... \n",
      "\n",
      "( 29 ,.,.) = \n",
      "   0   0   0  ...    0   0   0\n",
      "\n",
      "( 30 ,.,.) = \n",
      "   0   0   0  ...    0   0   0\n",
      "\n",
      "( 31 ,.,.) = \n",
      "   0   0   0  ...    0   0   0\n",
      "[torch.FloatTensor of size 32x1x4156]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "batch_size = 32\n",
    "# Number of input samples per item: receptive field plus (output_length - 1).\n",
    "input_length = model.receptive_field + model.output_length - 1\n",
    "# Cast with `dtype` (as the targets are cast with `ltype`) so the dummy input is a\n",
    "# CUDA tensor whenever the model was moved to the GPU; on CPU this is a no-op.\n",
    "input_data = Variable(torch.zeros([batch_size, 1, input_length]).type(dtype))\n",
    "print(input_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Profile one forward + backward pass on the dummy batch.\n",
    "# CUDA timing is requested only when a GPU is actually in use: the flag was\n",
    "# previously hard-coded to True, which is wrong on CPU-only machines.\n",
    "with torch.autograd.profiler.profile(enabled=True, use_cuda=use_cuda) as prof:\n",
    "    out = model(input_data)\n",
    "    # All-zero class indices, one per predicted sample of every batch item.\n",
    "    targets = Variable(torch.zeros([batch_size * model.output_length]).type(ltype))\n",
    "    loss = F.cross_entropy(out.squeeze(), targets)\n",
    "    loss.backward()\n",
    "# Per-operation averages, ordered by total CPU time.\n",
    "print(prof.key_averages().table(sort_by='cpu_time_total'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Save the most recent profile (`prof`) as a Chrome trace file, which can be\n",
    "# inspected under chrome://tracing.\n",
    "trace_path = 'profiling/latest_trace.json'\n",
    "# The export opens the file directly, so make sure the target folder exists.\n",
    "os.makedirs(os.path.dirname(trace_path), exist_ok=True)\n",
    "prof.export_chrome_trace(trace_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "one generating step does take approximately 0.012284069061279297 seconds)\n",
      "---------------  ---------------  ---------------  ---------------  ---------------  ---------------\n",
      "Name                    CPU time        CUDA time            Calls        CPU total       CUDA total\n",
      "---------------  ---------------  ---------------  ---------------  ---------------  ---------------\n",
      "view                     6.082us          0.000us                1          6.082us          0.000us\n",
      "squeeze                  1.857us          0.000us              100        185.680us          0.000us\n",
      "threshold                3.185us          0.000us              100        318.454us          0.000us\n",
      "div                      3.349us          0.000us              100        334.918us          0.000us\n",
      "softmax                  6.903us          0.000us              100        690.307us          0.000us\n",
      "unsqueeze                1.704us          0.000us             4000       6816.902us          0.000us\n",
      "mul                      2.180us          0.000us             4000       8721.320us          0.000us\n",
      "sigmoid                  2.209us          0.000us             4000       8837.406us          0.000us\n",
      "tanh                     2.279us          0.000us             4000       9115.436us          0.000us\n",
      "add                      2.339us          0.000us             8000      18713.918us          0.000us\n",
      "cat                      5.684us          0.000us             3508      19939.675us          0.000us\n",
      "SetItem                 10.601us          0.000us             4000      42403.212us          0.000us\n",
      "Index                    8.634us          0.000us            15408     133035.871us          0.000us\n",
      "ConvForward             16.638us          0.000us            16200     269537.283us          0.000us\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Profile a call to generate_fast that produces 100 samples.\n",
    "with torch.autograd.profiler.profile() as prof:\n",
    "    model.generate_fast(num_samples=100)\n",
    "\n",
    "# Per-operation averages, ordered by total CPU time.\n",
    "generation_summary = prof.key_averages().table(sort_by='cpu_time_total')\n",
    "print(generation_summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
